---
title: Visualization
keywords: fastai
sidebar: home_sidebar
summary: "Functions designed to visualize how the model is performing on the dataset via saliency maps."
---

First, let's train a model for 3 epochs to have something reasonable for visualization.

from breakhis_gradcam.data import initialize_datasets
from breakhis_gradcam.resnet import resnet18
import torch
from torch import nn
from torchvision import transforms

def get_tta_transforms(resize_shape, normalize_transform, n=5):
    tta = transforms.Compose([
        transforms.RandomRotation(15),
        transforms.RandomResizedCrop((resize_shape, resize_shape)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.ToTensor()
    ])
    original_transform = transforms.Compose([
        transforms.Resize((resize_shape, resize_shape)),
        transforms.ToTensor()
    ])
    return transforms.Compose([
        transforms.Lambda(
            lambda image: torch.stack(
                [tta(image) for _ in range(n)] + [original_transform(image)]
            )
        ),
        transforms.Lambda(
            lambda images: torch.stack([
                normalize_transform(image) for image in images
            ])
        ),
    ])

def get_transforms(resize_shape, tta=False, tta_n=5):
    random_resized_crop = transforms.RandomResizedCrop((resize_shape, resize_shape))
    random_horizontal_flip = transforms.RandomHorizontalFlip()
    resize = transforms.Resize((resize_shape, resize_shape))
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_transforms = transforms.Compose([
        random_resized_crop, random_horizontal_flip, transforms.ToTensor(), normalize
    ])
    val_transforms = (
        get_tta_transforms(resize_shape, normalize, n=tta_n) if tta
        else transforms.Compose([resize, transforms.ToTensor(), normalize])
    )
    return train_transforms, val_transforms
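With tta=True, the validation transform above returns a stack of n+1 tensors per image (n random augmentations plus the plain resized original), so a validation batch has shape (batch, n+1, 3, H, W). A minimal sketch of how predictions could be averaged over that stack at inference time (illustrative only; not necessarily how the package's validate helper implements it):

import torch.nn.functional as F

def tta_predict(model, batch):
    # batch: (B, T, 3, H, W), where T = n random augmentations + 1 plain resize
    b, t, c, h, w = batch.shape
    with torch.no_grad():
        logits = model(batch.view(b * t, c, h, w))       # forward all copies at once
        probs = F.softmax(logits, dim=1).view(b, t, -1)  # (B, T, num_classes)
    return probs.mean(dim=1)                             # average over the TTA copies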
    
train_transform, val_transform = get_transforms(224, tta=True)
ds_mapping = initialize_datasets(
    '/share/nikola/export/dt372/BreaKHis_v1/',
    label='tumor_class', criterion=['tumor_type', 'magnification'],
    split_transforms={'train': train_transform, 'val': val_transform}
)
tr_ds, val_ds = ds_mapping['train'], ds_mapping['val']
tr_dl = torch.utils.data.DataLoader(tr_ds, batch_size=32, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=32)
model = resnet18(pretrained=True, num_classes=2, create_log_and_save_dirs=False)
if torch.cuda.is_available():
    model = model.cuda()
mixup = True
num_epochs = 3
base_lr = 1e-3
finetune_body_factor = [1e-5, 1e-2]
param_lr_maps = get_param_lr_maps(model, base_lr, finetune_body_factor)
optimizer, scheduler = setup_optimizer_and_scheduler(param_lr_maps, base_lr, num_epochs, len(tr_dl))
criterion = {
    'train': nn.CrossEntropyLoss(reduction='none' if mixup else 'mean'),
    'val': nn.CrossEntropyLoss()
}
Setting up optimizer to fine-tune body with LR in range [0.00000001, 0.00001000] and head with LR 0.00100
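The log line above reflects discriminative learning rates: the pretrained body is fine-tuned with learning rates scaled down from base_lr by finetune_body_factor, while the new head trains at base_lr itself. A rough, hypothetical sketch of how such parameter groups could be built with plain PyTorch (get_param_lr_maps and setup_optimizer_and_scheduler may differ; layer names assume the standard torchvision-style ResNet attributes):

# Hypothetical sketch only: earlier body stages get the smallest LR, later ones a larger LR,
# and the classification head gets base_lr (matching the logged range [1e-8, 1e-5] for the body).
body_groups = [
    {'params': model.layer1.parameters(), 'lr': base_lr * 1e-5},
    {'params': model.layer2.parameters(), 'lr': base_lr * 1e-4},
    {'params': model.layer3.parameters(), 'lr': base_lr * 1e-3},
    {'params': model.layer4.parameters(), 'lr': base_lr * 1e-2},
]
head_group = {'params': model.fc.parameters(), 'lr': base_lr}
example_optimizer = torch.optim.Adam(body_groups + [head_group])  # sketch; not the optimizer used above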
clear_logging_handlers = setup_logging_streams(model, log_to_file=True, log_to_stdout=False)
for epoch in range(num_epochs):
    tr_loss, tr_acc = train(
        model, epoch + 1, tr_dl, criterion['train'], optimizer, scheduler=scheduler,
        mixup=mixup, alpha=0.4, logging_frequency=25
    )
    val_loss, val_acc = validate(
        model, epoch + 1, val_dl, criterion['val'], tta=True,
        logging_frequency=25
    )
    checkpoint_state(
        model, epoch + 1, optimizer, scheduler, tr_loss, tr_acc, val_loss, val_acc,
    )
clear_logging_handlers()
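The training criterion uses reduction='none' because mixup needs per-sample losses: each mixed image contributes a weighted combination of the losses for its two labels. A minimal sketch of one mixup training step, assuming the standard mixup formulation (the train helper above may implement the details differently):

import numpy as np

def mixup_step(model, inputs, labels, criterion, alpha=0.4):
    # Sample the mixing coefficient and a random pairing of the batch with itself.
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(inputs.size(0), device=inputs.device)
    mixed = lam * inputs + (1 - lam) * inputs[perm]
    outputs = model(mixed)
    # Per-sample losses for both label sets, combined with the same weights as the inputs.
    loss = (lam * criterion(outputs, labels) + (1 - lam) * criterion(outputs, labels[perm])).mean()
    return loss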

Now, with our trained model, let's use non-random transforms for inference and the corresponding visualizations.

from breakhis_gradcam.data import BreaKHisDataset  # assumed to live alongside initialize_datasets

resize = transforms.Resize((224, 224))
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
inference_transform = transforms.Compose([resize, transforms.ToTensor(), normalize])
inference_ds = BreaKHisDataset.initalize(
    '/share/nikola/export/dt372/BreaKHis_v1/', label='tumor_class',
    criterion=['tumor_type', 'magnification'],
    split={'all': 1.0},
    split_transforms={'all': inference_transform}
)['all'].dataset

show_image[source]

show_image(datapoint, ax=None)

Shows the image corresponding to datapoint (taken from a BreaKHisDataset object). Optionally provide an axis object ax from Matplotlib for multi-image plots.

Here's an example of what one of our images looks like.

show_image(inference_ds[0])
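Since show_image accepts a Matplotlib axis, several datapoints can be placed on one figure; for example (indices chosen arbitrarily):

import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, idx in zip(axes, [0, 1, 2]):
    show_image(inference_ds[idx], ax=ax)  # draw each datapoint on its own axis
plt.show()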

get_preprocessed_image[source]

get_preprocessed_image(datapoint, inference_transform)

Returns the pre-processed image and corresponding label ID using the inference_transform.

get_preprocessed_image(inference_ds[0], inference_transform)
(tensor([[[[ 0.9303,  1.0331,  1.0673,  ...,  1.3242,  1.3070,  1.1700],
           [ 0.9303,  1.0502,  1.1015,  ...,  1.3242,  1.2728,  1.1358],
           [ 0.9646,  1.1187,  1.1358,  ...,  1.1872,  1.0844,  1.0844],
           ...,
           [ 0.5878,  0.7419,  0.9132,  ...,  1.4954,  1.4612,  1.3584],
           [ 0.3652,  0.3823,  0.5193,  ...,  1.4269,  1.4098,  1.3584],
           [ 0.3652,  0.4508,  0.5707,  ...,  1.3755,  1.3927,  1.3584]],
 
          [[ 0.1702,  0.1527,  0.1176,  ...,  0.5378,  0.5378,  0.5203],
           [ 0.1352,  0.1352,  0.1176,  ...,  0.5203,  0.4678,  0.5028],
           [ 0.1352,  0.1527,  0.1527,  ...,  0.3803,  0.3102,  0.4503],
           ...,
           [ 0.0301,  0.1352,  0.3627,  ...,  0.6779,  0.5903,  0.5903],
           [-0.1975, -0.2500, -0.1275,  ...,  0.5903,  0.5203,  0.5728],
           [-0.1800, -0.2150, -0.1625,  ...,  0.5378,  0.4678,  0.5553]],
 
          [[ 0.9145,  1.0365,  1.0714,  ...,  1.2980,  1.3154,  1.3328],
           [ 0.8797,  1.0191,  1.0714,  ...,  1.2980,  1.2805,  1.2980],
           [ 0.8971,  1.0365,  1.0714,  ...,  1.1759,  1.1411,  1.2805],
           ...,
           [ 0.6008,  0.7402,  0.9668,  ...,  1.5420,  1.4722,  1.4897],
           [ 0.4788,  0.5311,  0.6356,  ...,  1.4548,  1.3851,  1.4548],
           [ 0.5311,  0.6531,  0.7228,  ...,  1.3677,  1.3502,  1.4548]]]],
        device='cuda:0'),
 0)
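The returned tensor is already batched and on the model's device, so class probabilities can be obtained directly. A small illustrative example, assuming the label mapping {0: benign, 1: malignant} implied by the outputs further below:

import torch.nn.functional as F

image, label = get_preprocessed_image(inference_ds[0], inference_transform)
model.eval()
with torch.no_grad():
    probs = F.softmax(model(image), dim=1).squeeze(0)  # (num_classes,)
print(f'label={label}, p(benign)={probs[0].item():.5f}, p(malignant)={probs[1].item():.5f}')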

show_heatmap_and_original[source]

show_heatmap_and_original(model, datapoint, inference_transform, show_for_label=True, show_for_prediction=False, label_type='tumor_class', show_activation_grid=False)

Shows a heatmap corresponding to the model's prediction for datapoint after transforming the image using inference_transform. Assumes that the model was trained on labels of label_type. Optionally show the activation grid by specifying show_activation_grid.

This is the main function for visualization. It shows an activation map built from gradient-weighted activations of the last convolutional stage of the model (specifically, the activations of layer4 for every ResNet). Note that by default the activation map is computed for the given label, i.e. it highlights the evidence the model sees for that label being correct. By setting show_for_label to False and show_for_prediction to True, one can instead see the activation heatmap for why the model might believe something other than the label is correct.
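Conceptually this is Grad-CAM: capture the layer4 activations on the forward pass, capture the gradients of the chosen class score with respect to those activations on the backward pass, weight each channel by its average gradient, then ReLU and upsample the weighted sum onto the image. A rough, self-contained sketch of that idea using hooks (not the package's exact implementation):

import torch.nn.functional as F

def gradcam_heatmap(model, image, class_idx):
    # image: (1, 3, H, W); model assumed to be a ResNet-style network in eval mode.
    acts, grads = {}, {}
    def fwd_hook(module, inp, out):
        acts['value'] = out.detach()
    def bwd_hook(module, grad_in, grad_out):
        grads['value'] = grad_out[0].detach()
    h1 = model.layer4.register_forward_hook(fwd_hook)
    h2 = model.layer4.register_full_backward_hook(bwd_hook)
    try:
        logits = model(image)
        logits[0, class_idx].backward()          # gradients of the chosen class score
    finally:
        h1.remove(); h2.remove()
    weights = grads['value'].mean(dim=(2, 3), keepdim=True)        # channel-wise gradient average
    cam = F.relu((weights * acts['value']).sum(dim=1, keepdim=True))
    cam = F.interpolate(cam, size=image.shape[-2:], mode='bilinear', align_corners=False)
    return (cam / (cam.max() + 1e-8)).squeeze()  # normalized heatmap in [0, 1]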

Below, the model trained above is visualized on a benign and a malignant example.

show_heatmap_and_original(model, inference_ds[0], inference_transform, show_activation_grid=False)
Model would have predicted benign (0.62375 vs. 0.62375)
Showing activation heatmap for the given label: benign
show_heatmap_and_original(model, inference_ds[0], inference_transform, show_activation_grid=True)
Model would have predicted benign (0.62375 vs. 0.62375)
Showing activation heatmap for the given label: benign
show_heatmap_and_original(model, inference_ds[4000], inference_transform, show_activation_grid=False)
Model would have predicted malignant (0.93616 vs. 0.93616)
Showing activation heatmap for the given label: malignant
show_heatmap_and_original(model, inference_ds[4000], inference_transform, show_activation_grid=True)
Model would have predicted malignant (0.93616 vs. 0.93616)
Showing activation heatmap for the given label: malignant
show_heatmap_and_original(
    model, inference_ds[3], inference_transform, show_for_label=True, show_activation_grid=True
)
Model would have predicted malignant (0.65845 vs. 0.34155)
Showing activation heatmap for the given label: benign
show_heatmap_and_original(
    model, inference_ds[3], inference_transform, show_for_label=False, show_for_prediction=True,
    show_activation_grid=True
)
Model would have predicted malignant (0.65845 vs. 0.34155)
Showing activation heatmap for the model's prediction: malignant